"""
Authentication middleware for LandPPT using Supertokens
"""

from typing import Optional, Callable, Union
from fastapi import Request, Response, HTTPException, Depends
from fastapi.responses import RedirectResponse
from sqlalchemy.orm import Session
import logging

from supertokens_python.recipe.session.framework.fastapi import verify_session
from supertokens_python.recipe.session.interfaces import SessionContainer
from supertokens_python.recipe.session import SessionRecipe
from supertokens_python.recipe.session.syncio import create_new_session, get_session

from .auth_service import get_auth_service, AuthService
from ..database.database import get_db
from ..database.models import User
from ..core.config import app_config

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class AuthMiddleware:
    """Authentication middleware using Supertokens.

    Requests whose path is public are passed straight through. Every other
    request must carry a session accepted by ``verify_session`` and map to a
    user known to the auth service; that user is stored on
    ``request.state.user`` before the request continues downstream.

    Rejected requests get a JSON error body when the path starts with
    ``/api/`` and a redirect to ``/auth/login`` otherwise.
    """

    def __init__(self):
        """Load the auth service and the path lists that skip authentication."""
        self.auth_service = get_auth_service()
        # Paths that do not require authentication (exact match).
        self.public_paths = {
            "/",
            "/auth/login",
            "/auth/logout",
            "/api/auth/login",
            "/api/auth/logout",
            "/api/auth/check",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/static",
            "/favicon.ico"
        }
        # Path prefixes that do not require authentication.
        self.public_prefixes = [
            "/static/",
            "/temp/",  # temp directory, used for cached image access
            "/api/image/view/",  # hosted image access needs no authentication
            "/api/image/thumbnail/",  # image thumbnails need no authentication
            "/ppt/embed/",  # embed entry routes need no authentication
            "/api/embed/",  # embed API routes need no authentication
            "/docs",
            "/redoc",
            "/openapi.json"
        ]

    def is_public_path(self, path: str) -> bool:
        """Check if path is public (doesn't require authentication).

        Args:
            path: URL path of the incoming request.

        Returns:
            True when ``path`` equals an entry of ``public_paths`` or starts
            with an entry of ``public_prefixes``; False otherwise.
        """
        if path in self.public_paths:
            return True
        # str.startswith accepts a tuple of candidates; an empty tuple
        # simply yields False.
        return path.startswith(tuple(self.public_prefixes))

    @staticmethod
    def _deny(path: str, detail: str, status_code: int):
        """Build a rejection: JSON for ``/api/`` paths, login redirect otherwise."""
        if path.startswith("/api/"):
            return Response(
                # ``detail`` is always one of the fixed messages used in this
                # class, so it needs no JSON escaping.
                content=f'{{"detail": "{detail}"}}',
                status_code=status_code,
                media_type="application/json"
            )
        return RedirectResponse(url="/auth/login", status_code=302)

    async def __call__(self, request: Request, call_next: Callable):
        """Authenticate the request, then hand it to the next handler.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream and
                returns its response.

        Returns:
            The downstream response when the path is public or the request
            is authenticated; otherwise a 401/500 JSON response (``/api/``
            paths) or a 302 redirect to ``/auth/login``.

        Raises:
            Any exception raised by ``call_next``. Downstream failures are
            deliberately not reported as authentication errors.
        """
        path = request.url.path

        # Skip authentication for public paths
        if self.is_public_path(path):
            return await call_next(request)

        db = None
        try:
            try:
                # NOTE(review): verify_session() is called with its default
                # arguments; it may raise instead of returning None when no
                # session exists, in which case the request ends up in the
                # except branch below rather than the 401 branch - confirm.
                session = await verify_session()(request)
                if not session:
                    return self._deny(path, "Authentication required", 401)

                # Keep a reference to the generator for as long as the
                # session is in use: if it were garbage-collected early, any
                # cleanup code get_db has after its yield would run while the
                # session is still needed.
                db_gen = get_db()
                db = next(db_gen)

                # Get user from database using session user ID
                user = self.auth_service.get_user_by_id(db, session.get_user_id())
            except Exception as exc:
                # logger.exception keeps the traceback alongside the message.
                logger.exception("Authentication middleware error: %s", exc)
                return self._deny(path, "Authentication error", 500)

            if not user:
                return self._deny(path, "User not found", 401)

            # Add user to request state
            request.state.user = user

            # Called outside the except block above on purpose: an error in a
            # downstream handler is not an authentication error and must not
            # be turned into a login redirect.
            return await call_next(request)
        finally:
            # The session stays open while the request is handled downstream
            # and is closed once the response has been produced.
            if db is not None:
                db.close()


def get_current_user(request: Request) -> Optional[User]:
    """Return the ``user`` attribute of ``request.state``, or None if it is not set."""
    request_state = request.state
    return getattr(request_state, "user", None)


def require_auth(request: Request) -> User:
    """Dependency that yields the current user or fails with HTTP 401.

    Raises:
        HTTPException: 401 when ``get_current_user`` returns a falsy value.
    """
    current = get_current_user(request)
    if current:
        return current
    raise HTTPException(status_code=401, detail="Authentication required")


def require_admin(request: Request) -> User:
    """Dependency that yields the current user only if they are an admin.

    Raises:
        HTTPException: 401 (via ``require_auth``) when there is no user,
            403 when the user's ``is_admin`` is falsy.
    """
    account = require_auth(request)
    if account.is_admin:
        return account
    raise HTTPException(status_code=403, detail="Admin privileges required")


def get_current_user_optional(
    request: Request,
    db: Session = Depends(get_db)
) -> Optional[User]:
    """Get current user if authenticated, None otherwise.

    Best-effort lookup: any error while reading the session or loading the
    user is treated as "not authenticated".

    Args:
        request: Incoming request the session is read from.
        db: Database session used to load the user.

    Returns:
        The user returned by the auth service for the session's user ID, or
        None when there is no session or the lookup fails.
    """
    try:
        session = get_session(request)
        if not session:
            return None

        auth_service = get_auth_service()
        return auth_service.get_user_by_id(db, session.get_user_id())
    except Exception as exc:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        logger.debug("Optional user lookup failed: %s", exc)
        return None


def get_current_user_required(
    request: Request,
    db: Session = Depends(get_db),
    session: SessionContainer = Depends(verify_session())
) -> User:
    """Resolve the session's user ID to a user, failing with HTTP 401 if none is found.

    Raises:
        HTTPException: 401 when the auth service returns a falsy value for
            the session's user ID.
    """
    service = get_auth_service()
    account = service.get_user_by_id(db, session.get_user_id())
    if account:
        return account
    raise HTTPException(status_code=401, detail="Authentication required")


def get_current_admin_user(
    request: Request,
    db: Session = Depends(get_db),
    session: SessionContainer = Depends(verify_session())
) -> User:
    """Resolve the session's user ID to a user and require ``is_admin``.

    Raises:
        HTTPException: 401 when the auth service returns a falsy value for
            the session's user ID, 403 when the user's ``is_admin`` is falsy.
    """
    service = get_auth_service()
    account = service.get_user_by_id(db, session.get_user_id())
    if not account:
        raise HTTPException(status_code=401, detail="Authentication required")
    if account.is_admin:
        return account
    raise HTTPException(status_code=403, detail="Admin privileges required")


def create_auth_middleware() -> AuthMiddleware:
    """Build and return a new ``AuthMiddleware``."""
    middleware = AuthMiddleware()
    return middleware


# Utility functions for templates
def is_authenticated(request: Request) -> bool:
    """Tell whether ``get_current_user`` finds a user on the request."""
    current = get_current_user(request)
    return current is not None


def is_admin(request: Request) -> bool:
    """Return False when the request has no user, else that user's ``is_admin`` value."""
    current = get_current_user(request)
    if current is None:
        return False
    return current.is_admin


def get_user_info(request: Request) -> Optional[dict]:
    """Return the current user's ``to_dict()`` result for templates, or None without a user."""
    current = get_current_user(request)
    return current.to_dict() if current else None
